# This file includes code originally from the Pytorch Correlation repository:
# https://github.com/ClementPinard/Pytorch-Correlation-extension
# Licensed under the MIT License. See THIRD_PARTY_LICENSES.md for details.

from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from vipe.ext import corr_ext as correlation


def spatial_correlation_sample(
    input1, input2, kernel_size=1, patch_size=1, stride=1, padding=0, dilation=1, dilation_patch=1
):
    """Apply spatial correlation sampling from ``input1`` to ``input2``.

    Functional entry point: every argument is forwarded, in order, to
    ``SpatialCorrelationSamplerFunction.apply``.

    Every parameter except ``input1`` and ``input2`` can be either a single
    int or a pair of ints. For more information about spatial correlation
    sampling, see
    https://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/

    Args:
        input1: The first input.
        input2: The second input.
        kernel_size: Total size of the correlation kernel, in pixels.
        patch_size: Total size of the patch, determining how many different
            shifts will be applied.
        stride: Stride of the spatial sampler; modifies output height and
            width.
        padding: Padding applied to ``input1`` and ``input2`` before the
            correlation sampling; modifies output height and width.
        dilation: Dilation value passed through to the compiled extension.
            NOTE(review): presumably the spacing between kernel elements, as
            in convolution dilation -- confirm against the extension.
        dilation_patch: Step for every shift in the patch.

    Returns:
        Tensor: Result of the correlation sampling.
    """
    # Function.apply only takes positional arguments, so the hyper-parameters
    # are bundled in the exact order expected by the autograd function.
    sampler_settings = (kernel_size, patch_size, stride, padding, dilation, dilation_patch)
    return SpatialCorrelationSamplerFunction.apply(input1, input2, *sampler_settings)


class SpatialCorrelationSamplerFunction(Function):
    """Autograd ``Function`` that delegates to the compiled correlation extension.

    ``forward`` normalizes every hyper-parameter to a 2-tuple with ``_pair``,
    stores the tuples on ``ctx`` and calls ``correlation.forward``.
    ``backward`` reads them back and calls ``correlation.backward``.
    """

    @staticmethod
    def forward(ctx, input1, input2, kernel_size=1, patch_size=1, stride=1, padding=0, dilation=1, dilation_patch=1):
        """Run the forward correlation sampling.

        Args:
            ctx: Autograd context; receives the saved inputs and the
                normalized hyper-parameters.
            input1: The first input tensor.
            input2: The second input tensor.
            kernel_size: int or pair of ints, correlation kernel size.
            patch_size: int or pair of ints, patch size.
            stride: int or pair of ints, stride of the sampler.
            padding: int or pair of ints, padding applied to the inputs.
            dilation: int or pair of ints, kernel dilation.
            dilation_patch: int or pair of ints, step between patch shifts.

        Returns:
            Whatever ``correlation.forward`` returns for these arguments.
        """
        # Both inputs are needed again to compute gradients.
        ctx.save_for_backward(input1, input2)

        # Normalize "int or pair" arguments to 2-tuples.
        kernel = _pair(kernel_size)
        patch = _pair(patch_size)
        pad = _pair(padding)
        dil = _pair(dilation)
        dil_patch = _pair(dilation_patch)
        step = _pair(stride)

        # Explicit two-name unpacking keeps the "must be a pair" check in
        # Python rather than deferring it to the extension call.
        kernel_h, kernel_w = kernel
        patch_h, patch_w = patch
        pad_h, pad_w = pad
        dil_h, dil_w = dil
        dil_patch_h, dil_patch_w = dil_patch
        step_h, step_w = step

        # Remember the normalized values for backward().
        ctx.kernel_size = kernel
        ctx.patch_size = patch
        ctx.padding = pad
        ctx.dilation = dil
        ctx.dilation_patch = dil_patch
        ctx.stride = step

        # Argument order is dictated by the extension: kernel, patch,
        # padding, dilation, patch dilation, stride (height before width).
        return correlation.forward(
            input1,
            input2,
            kernel_h,
            kernel_w,
            patch_h,
            patch_w,
            pad_h,
            pad_w,
            dil_h,
            dil_w,
            dil_patch_h,
            dil_patch_w,
            step_h,
            step_w,
        )

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients with respect to the two inputs.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient flowing back from the forward output.

        Returns:
            An 8-tuple: gradients for ``input1`` and ``input2`` followed by
            ``None`` for each of the six hyper-parameters of ``forward``.
        """
        saved_input1, saved_input2 = ctx.saved_tensors

        kernel_h, kernel_w = ctx.kernel_size
        patch_h, patch_w = ctx.patch_size
        pad_h, pad_w = ctx.padding
        dil_h, dil_w = ctx.dilation
        dil_patch_h, dil_patch_w = ctx.dilation_patch
        step_h, step_w = ctx.stride

        grad_input1, grad_input2 = correlation.backward(
            saved_input1,
            saved_input2,
            grad_output,
            kernel_h,
            kernel_w,
            patch_h,
            patch_w,
            pad_h,
            pad_w,
            dil_h,
            dil_w,
            dil_patch_h,
            dil_patch_w,
            step_h,
            step_w,
        )

        # One slot per forward() argument: the six hyper-parameters are not
        # differentiable, so their gradient slots are None.
        return (grad_input1, grad_input2) + (None,) * 6


class SpatialCorrelationSampler(nn.Module):
    """Module wrapper around ``SpatialCorrelationSamplerFunction``.

    The hyper-parameters are stored as given at construction time and passed
    to the autograd function on every ``forward`` call.
    """

    def __init__(self, kernel_size=1, patch_size=1, stride=1, padding=0, dilation=1, dilation_patch=1):
        """Store the sampling hyper-parameters unchanged.

        Args:
            kernel_size: Correlation kernel size.
            patch_size: Patch size.
            stride: Stride of the sampler.
            padding: Padding applied to the inputs.
            dilation: Kernel dilation.
            dilation_patch: Step between patch shifts.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.dilation_patch = dilation_patch

    def forward(self, input1, input2):
        """Apply correlation sampling to ``input1`` and ``input2``.

        Returns:
            The result of ``SpatialCorrelationSamplerFunction.apply`` called
            with the two inputs and the stored hyper-parameters.
        """
        # Same positional order as SpatialCorrelationSamplerFunction.forward.
        hyper_params = (
            self.kernel_size,
            self.patch_size,
            self.stride,
            self.padding,
            self.dilation,
            self.dilation_patch,
        )
        return SpatialCorrelationSamplerFunction.apply(input1, input2, *hyper_params)
